import os
import sys
import subprocess
import logging
from typing import Literal
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import polars as pl
import seaborn as sns

from Classes import FileManagement, Gender
from myutil import small_tools

# matplotlib.use('Agg')
logger = small_tools.create_logger("MainLogger", level=logging.WARNING)


def missing(
    fm: FileManagement,
    input_name: str,
    save_path_name: str,
    ethnic: str | None = None,
    gender: Literal["Men", "Women"] | None = None
) -> None:
    """
    Visualise missing data proportions.

    # Args:

    **fm** (FileManagement): _Provides the path to the plink executable._

    **input_name** (str): _Name of the input bfile set; also used as plink's `--out` prefix._

    **save_path_name** (str): _Path to save the visualisations (without suffix and extension)._

    **ethnic** (str | None): _Ethnic group of the data set (used in plot titles only)._

    **gender** ("Men" | "Women" | None): _Gender of the data set (used in plot titles only)._

    # Returns:

    **None** _(the process exits with status 2 if plink fails)_

    # Generate Files:

    **${input_name}.imiss**: _Missingness of individuals (plink output)._

    **${input_name}.lmiss**: _Missingness of SNPs (plink output)._

    **${save_path_name}_imiss.png**: _Visualisation of missingness of individuals._

    **${save_path_name}_lmiss.png**: _Visualisation of missingness of SNPs._
    """
    logger.info(
        "Calculating proportions of missing data of %s %s data set", ethnic, gender)
    try:
        # `plink --missing` writes <out>.imiss (per-individual) and
        # <out>.lmiss (per-SNP) missingness reports.
        command = [
            fm.plink,
            "--bfile", input_name,
            "--missing",
            "--out", input_name
        ]
        subprocess.run(
            command,
            stdout=subprocess.DEVNULL,
            stderr=None,
            check=True
        )
    except subprocess.CalledProcessError as e:
        logger.error(
            f"Error calculating proportions of missing data plink: {e}")
        sys.exit(2)
    logger.info("Finished Calculation")

    logger.info("Visualising missing data")

    # Visualise missingness of individuals.
    logger.info("Visualising missingness of individuals")
    imiss_df = pd.read_csv(f"{input_name}.imiss",
                           sep=r"\s+", usecols=lambda col: col in ["F_MISS"])

    # density=True divides each bin's raw count by the total number of counts
    # and the bin width, so the area under the histogram integrates to 1.
    plt.hist(imiss_df, density=True, bins=20)
    # Kept on one line: an expression broken across lines inside f-string
    # braces is a SyntaxError before Python 3.12.
    plt.title(
        f"Histogram of SNP missingness per individual from {ethnic} {gender} data set")
    plt.xlabel("Individuals' SNP missing rate")
    plt.ylabel("Frequency / Intercept")
    # Remove whitespace and avoid overlap around the plot.
    plt.tight_layout()
    plt.savefig(f"{save_path_name}_imiss.png", dpi=300)
    # Clear figure to prevent conflicts with the next plot.
    plt.clf()

    # Visualise missingness of SNPs.
    lmiss_df = pd.read_csv(f"{input_name}.lmiss",
                           engine="c", sep=r"\s+", usecols=lambda col: col in ["F_MISS"])

    plt.hist(lmiss_df, density=True, bins=20)
    plt.title(
        f"Histogram of individual missingness per SNP from {ethnic} {gender} data set")
    plt.xlabel("SNPs' individual missing rate")
    plt.ylabel("Frequency / Intercept")
    plt.tight_layout()
    plt.savefig(f"{save_path_name}_lmiss.png", dpi=300)
    plt.clf()


def hardy_weinberg(
    fm: FileManagement,
    input_name: str,
    save_path_name: str,
    ethnic: str | None = None,
    gender: Gender = Gender.UNKNOWN
) -> None:
    """
    Visualise p-values generated by the Hardy-Weinberg Equilibrium test.

    Args:
        **fm** (FileManagement): _Provides the path to the plink executable._
        **input_name** (str): _Name of the input bfile set; also used as plink's `--out` prefix._
        **save_path_name** (str): _Path to save the visualisation (without suffix and extension)._
        **ethnic** (str | None): _Ethnic group of the data set (plot title only)._
        **gender** (Gender): _Gender of the data set (plot title only)._

    Generate Files:
        **${input_name}.hwe**: _p-value and other statistical information of Hardy Weinberg Equilibrium (plink output)._
        **${save_path_name}.png**: _Visualisation of p-value of HWE._

    Returns early (without plotting) if plink fails.
    """
    # Calculate HWE.
    logger.info("Calculating p-value of HWE")
    try:
        # `plink --hardy` writes the <out>.hwe report.
        command = [
            fm.plink,
            "--bfile", input_name,
            "--hardy",
            "--out", input_name
        ]
        subprocess.run(
            command,
            stdout=subprocess.DEVNULL,
            stderr=None,
            check=True
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"Error running plink: {e}")
        return
    logger.info("Finished Calculation")

    # Draw histogram of the HWE p-values.
    hwe = pd.read_csv(f"{input_name}.hwe", sep=r"\s+",
                      usecols=lambda col: col in ["P"])
    plt.hist(hwe, bins=10)
    plt.xlabel("p value")
    plt.ylabel("Frequency / Intercept")
    plt.title(f"Histogram of HWE from {ethnic} {gender} data set")
    plt.tight_layout()
    plt.savefig(f"{save_path_name}.png", dpi=300)
    # Clear figure to prevent conflicts with the next plot.
    plt.clf()


def minor_allele_frequency(
    fm: FileManagement,
    input_name: str,
    save_path_name: str,
    /,
    ethnic: str | None = None,
    gender: Gender = Gender.UNKNOWN
) -> None:
    """
    Visualise minor allele frequency (MAF).

    Args:
        **fm** (FileManagement): _Provides the path to the plink executable._
        **input_name** (str): _Name of the input bfile set; also used as plink's `--out` prefix._
        **save_path_name** (str): _Path to save the visualisation (without suffix and extension)._
        **ethnic** (str | None): _Ethnic group of the data set (plot title only)._
        **gender** (Gender): _Gender of the data set (plot title only)._

    Generate Files:
        **${input_name}.frq**: _Allele frequency report (plink output)._
        **${save_path_name}.png**: _Histogram of MAF values._

    Returns early (without plotting) if plink fails.
    """
    try:
        # `plink --freq` writes the <out>.frq allele-frequency report.
        command = [
            fm.plink,
            "--bfile", input_name,
            "--freq",
            "--out", input_name
        ]
        subprocess.run(
            command,
            stdout=subprocess.DEVNULL,
            stderr=None,
            check=True
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"Error running plink: {e}")
        return
    except Exception as e:
        # Best-effort boundary, e.g. FileNotFoundError if the plink binary is missing.
        logger.error(f"Unexpected error occurred: {e}")
        return
    logger.info("Finished Calculation")

    freq_file = pd.read_csv(
        f"{input_name}.frq", sep=r"\s+", engine="c", usecols=lambda col: col in ["MAF"]
    )

    plt.hist(freq_file["MAF"], bins=20)
    plt.title(f"MAF check of {ethnic} {gender} data set")
    plt.tight_layout()
    plt.savefig(f"{save_path_name}.png", dpi=300)
    plt.clf()


# Revised the association-analysis visualisation code so it can run easily (headless) on a server.
def assoc_visualisation(
    file_path: str,
    output_path: str,
    gender: Gender,
    ethnic: str,
    phenotype: str,
    n: int | None = None,
    alpha: float = 0.05,
):
    """Visualise association analysis result.

    Generates a Manhattan plot (`{output_path}_Manhattan.png`) and a QQ plot
    (`{output_path}_QQ.png`).

    Args:
        file_path (str): Path to the association result file (with .qassoc or .assoc extension).
        output_path (str): Path to save the visualisations (without file extension).
        gender (Gender): Gender of the sample.
        ethnic (str): Ethnic group of the sample.
        phenotype (str): Phenotype name.
        n (int | None, optional): Number of samples to perform Bonferroni's correction.
            Defaults to None, in which case the number of SNPs in the file is used.
        alpha (float, optional): Significance threshold. Defaults to 0.05.
    """
    mpl.use("Agg")  # Use non-interactive backend for matplotlib
    try:
        logger.info("开始关联性分析可视化")

        if not os.path.exists(file_path):
            logger.warning(f"{file_path} does not exist!")
            return

        logger.debug("读取文件")
        # Read SNP names and p-values; give every SNP a sequential ID used as
        # the Manhattan x-coordinate.
        a_m = pd.read_csv(file_path, engine="c",
                          sep=r"\s+", usecols=lambda col: col in ["SNP", "P"])
        a_m["ID"] = list(range(a_m.shape[0]))
        # Bonferroni-corrected significance threshold.
        threshold = alpha / a_m.shape[0] if n is None else alpha / n

        x = a_m["ID"]
        y = -np.log10(a_m["P"])

        # SNPs passing the corrected threshold get a text label on the plot.
        t_pd = a_m.loc[a_m["P"] < threshold, :]

        sns.set_theme("paper", style="whitegrid")
        # Manhattan Plot
        logger.debug("开始绘制曼哈顿图")
        plt.figure(figsize=(10, 5), dpi=300)
        plt.scatter(x, y, s=2, c=y, cmap='viridis', marker="o")
        # Horizontal significance line.
        plt.axhline(y=-np.log10(threshold), color='red', linestyle='--')

        # Annotate significant SNPs.
        for _, row_ in t_pd.iterrows():
            plt.text(row_["ID"], -np.log10(row_["P"]), row_["SNP"],
                     rotation=30, fontsize=10, ha='left', va='bottom')

        # Single-line f-string: a line break inside {} is a SyntaxError
        # before Python 3.12.
        plt.title(
            f"Manhattan Plot of Assoc Result of {ethnic} {gender} on {phenotype}"
        )
        plt.colorbar(label=r'$-\log_{10}P-value$')

        # Adjust margins; tight_layout then auto-fits to avoid overlap and clipping.
        plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1)
        plt.tight_layout()
        plt.savefig(f"{output_path}_Manhattan.png", dpi=600)
        plt.close()

        # QQ-Plot
        logger.debug("开始绘制QQ图")
        sns.set_theme("paper", style="whitegrid")
        plt.figure(figsize=(5, 5), dpi=300)

        # Theoretical -log10(P) values: midpoints of m equal-probability bins.
        x = np.linspace(0.5 / a_m.shape[0], 1 -
                        0.5 / a_m.shape[0], a_m.shape[0])
        sorted_p_values = -np.log10(a_m["P"].sort_values(ascending=True))

        plt.scatter(-np.log10(x), sorted_p_values, marker="^",
                    facecolors="none", edgecolors="b")

        # Add the y = x reference line.
        max_val = max(-np.log10(x).max(), sorted_p_values.max())
        plt.plot([0, max_val], [0, max_val], color="#E53528", lw=1)

        plt.xlabel(r"Theoretical $-\log_{10}P$ Value")  # raw string avoids escape issues
        plt.ylabel(r"Observed $-\log_{10}P$ Value")
        plt.title(
            f"QQ-Plot of Assoc Result of {ethnic} {gender} on {phenotype}")
        plt.tight_layout()
        plt.savefig(f"{output_path}_QQ.png", dpi=600)
        plt.close()
        # Fixed: report the actual saved paths (previously logged the input
        # file's basename, which is not where the PNGs were written).
        logger.debug(
            f'已输出至："{output_path}_QQ.png" 和 "{output_path}_Manhattan.png"中'
        )
    except Exception as e:
        # Top-level boundary: include the traceback, not just the message.
        logger.exception(f"Error: {e}")


def assoc_mperm_visualisation(
    file_path: str,
    output_path: str,
    /,
    *,
    gender: Gender,
    ethnic_name: str,
    phenotype_name: str,
    n: int,
    alpha: float = 0.05,
):
    """
    Visualise `plink --assoc mperm=<int>` result.

    Two Manhattan plots and one QQ plot will be generated.

    Manhattan plot:

        One shows P values of the permutation tests; SNPs with an empirical
        P value (EMP2) below alpha will be marked.

        The other shows P values of the assoc results. Threshold is alpha / n .

    Parameters:
        file_path (str):
            Path to the `plink --assoc mperm=<int>` result (**without .qassoc extension**).
            The association result is expected to be in `.qassoc` format.
        output_path (str):
            Path to the output file (without file extension).
        gender (Gender):
            Gender of the sample.
        ethnic_name (str):
            Ethnicity of the sample.
        phenotype_name (str):
            Name of the phenotype.
        n (int):
            Number of samples, used for Bonferroni's correction.
        alpha (float):
            Significance threshold.

    Raises:
        FileNotFoundError: If the `.qassoc.mperm` file does not exist.
    """
    mpl.use("Agg")  # Use non-interactive backend for matplotlib
    logger.info("Start visualising `--assoc mperm=<int>` result")

    # plink writes <prefix>.qassoc and <prefix>.qassoc.mperm .
    file_path = f"{file_path}.qassoc"
    mperm_path = f"{file_path}.mperm"

    if not os.path.exists(mperm_path):
        raise FileNotFoundError(f"{mperm_path} not found.")

    # Join the raw association p-values (P) with the permutation p-values
    # (EMP2) per SNP, and index rows for the Manhattan x-axis.
    res_df: pl.DataFrame = pl.from_pandas(
        pd.read_csv(file_path, sep=r"\s+", usecols=["SNP", "P"])  # type: ignore
    )
    perm_df: pl.DataFrame = pl.from_pandas(
        pd.read_csv(mperm_path, sep=r"\s+", usecols=["SNP", "EMP2"])  # type: ignore
    )
    concat_df = res_df.join(perm_df, on="SNP", how="inner").with_row_index()
    mperm_positive_df = concat_df.filter(pl.col("EMP2") < alpha)
    assoc_positive_df = concat_df.filter(pl.col("P") < alpha / n)

    # Manhattan Plot: permutation test
    logger.debug("Plotting Manhattan plot: permutation test")
    sns.set_theme("paper", style="white")
    plt.figure(figsize=(10, 5), dpi=300)

    # scatter
    plt.scatter(
        concat_df["index"],
        -concat_df["EMP2"].log10(),
        s=2,
        c=-concat_df["EMP2"].log10(),
        cmap="viridis",
        marker="o",
    )
    # horizontal significance line
    plt.axhline(y=-np.log10(alpha), color="red", linestyle="--")
    # annotate significant SNPs
    for row in mperm_positive_df.iter_rows(named=True):
        plt.text(
            row["index"],
            -np.log10(row["EMP2"]),
            row["SNP"],
            rotation=30,
            fontsize=10,
            ha="left",
            va="bottom",
        )

    # Single-line f-string: a line break inside {} is a SyntaxError before
    # Python 3.12.
    plt.title(
        f"Manhattan Plot of Permutation Test Result of {ethnic_name}, {gender} on {phenotype_name}"
    )
    plt.colorbar(label=r"$-log_{10}P$")
    # save
    plt.tight_layout()
    plt.savefig(f"{output_path}_Manhattan_mperm.png", dpi=600)
    plt.close()

    # Manhattan Plot: original assoc
    logger.debug("Plotting Manhattan plot: original association result")
    sns.set_theme("paper", style="white")
    # dpi=300 added for consistency with the other figures (savefig dpi still
    # controls the written PNG).
    plt.figure(figsize=(10, 5), dpi=300)

    # scatter
    plt.scatter(
        concat_df["index"],
        -concat_df["P"].log10(),
        s=2,
        c=-concat_df["P"].log10(),
        cmap="viridis",
        marker="o",
    )
    # horizontal significance line (Bonferroni-corrected)
    plt.axhline(y=-np.log10(alpha / n), color="red", linestyle="--")
    # annotate significant SNPs
    for row in assoc_positive_df.iter_rows(named=True):
        plt.text(
            row["index"],
            -np.log10(row["P"]),
            row["SNP"],
            rotation=30,
            fontsize=10,
            ha="left",
            va="bottom",
        )

    plt.title(
        f"Manhattan Plot of Association Result of {ethnic_name}, {gender} on {phenotype_name}"
    )
    plt.colorbar(label=r"$-log_{10}P$-value")
    # save
    plt.tight_layout()
    plt.savefig(f"{output_path}_Manhattan_assoc.png", dpi=600)
    plt.close()

    # QQ plot
    logger.debug("Plotting QQ plot")
    sns.set_theme("paper", style="white")
    plt.figure(figsize=(5, 5), dpi=300)
    # Theoretical -log10(P) values: midpoints of m equal-probability bins.
    snp_num = concat_df.shape[0]
    x = np.linspace(0.5 / snp_num, 1 - 0.5 / snp_num, snp_num)
    sorted_p_values = -np.log10(concat_df["P"].sort(descending=False))
    # Scatter
    plt.scatter(
        -np.log10(x), sorted_p_values, marker="^", facecolors="none", edgecolors="b"
    )
    # Add y=x reference line
    plt.axline((0, 0), (1, 1), color="#E53528", lw=1)
    # Label and title
    plt.xlabel(r"Theoretical $-log_{10}P$ value")
    plt.ylabel(r"Observed $-log_{10}P$ value")
    plt.title(
        f"QQ-Plot of Association Result of {ethnic_name}, {gender} on {phenotype_name}"
    )
    # save
    plt.tight_layout()
    plt.savefig(f"{output_path}_QQ.png", dpi=600)
    plt.close()
    logger.debug(
        f'已输出至："{output_path}_QQ.png" 和 "{output_path}_Manhattan_mperm.png" 和 "{output_path}_Manhattan_assoc.png"中'
    )
